import os
import struct
import numpy as np
import matplotlib.pyplot as plt


def load_mnist(path, kind='train'):
    """Load an MNIST split stored in the raw (uncompressed) IDX format.

    Reads ``<kind>-labels-idx1-ubyte`` and ``<kind>-images-idx3-ubyte``
    from the directory ``path``.

    Reference: https://blog.csdn.net/simple_the_best/article/details/75267863

    Parameters
    ----------
    path : str
        Directory containing the two IDX files.
    kind : str
        File-name prefix selecting the split (default ``'train'``).

    Returns
    -------
    images : numpy.ndarray
        uint8 array of shape ``(num_images, rows * cols)``: one flattened
        image per row (784 columns for standard 28x28 MNIST).
    labels : numpy.ndarray
        uint8 array of shape ``(num_images,)``.

    Raises
    ------
    ValueError
        If either file has an unexpected magic number, holds fewer bytes
        than its header declares, or the image and label counts disagree.
    """
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
        # Header: big-endian magic number and number of labels.
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError('%s: bad label-file magic number %d (expected 2049)'
                             % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8, count=n)
    if labels.size != n:
        raise ValueError('%s: expected %d labels, found %d'
                         % (labels_path, n, labels.size))

    with open(images_path, 'rb') as imgpath:
        # Header: big-endian magic number, image count, rows, columns.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError('%s: bad image-file magic number %d (expected 2051)'
                             % (images_path, magic))
        pixels = rows * cols
        images = np.fromfile(imgpath, dtype=np.uint8, count=num * pixels)
    if images.size != num * pixels:
        raise ValueError('%s: expected %d pixel bytes, found %d'
                         % (images_path, num * pixels, images.size))
    if num != n:
        raise ValueError('image count %d does not match label count %d'
                         % (num, n))

    # Flatten each image to a single row, sized from the header rather
    # than assuming 28x28.
    return images.reshape(num, pixels), labels


def plot_show(x,y):
    """Display the first sample of every digit 0-9 in a 2x5 image grid.

    x: array whose rows are flattened images; each selected row is
       reshaped to 28x28 before drawing.
    y: labels aligned with the rows of x. A digit with no samples makes
       the ``[0]`` lookup raise IndexError.
    """
    _fig, axes = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True)
    axes = axes.flatten()

    # The grid holds exactly ten axes, one per digit.
    for digit, axis in enumerate(axes):
        first_sample = x[y == digit][0]
        axis.imshow(first_sample.reshape(28, 28),
                    cmap='Greys', interpolation='nearest')

    # The subplots share x/y axes, so clearing ticks on the first one
    # clears them for the whole grid.
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    plt.tight_layout()
    plt.show()

def scatter_show(x,y):
    """Scatter every value of x at the x-position of its digit label (0-9).

    Each digit is drawn as its own colored, labelled series. A legend is
    added and the figure is shown.
    """
    palette = ["red", "green", "blue", "black", "pink",
               "brown", "gold", "yellow", "gray", "aqua"]
    # One palette entry per digit 0-9.
    for digit, colour in enumerate(palette):
        values = x[y == digit]
        plt.scatter([digit] * len(values), values, c=colour, label=str(digit))
    plt.legend(loc='best')
    plt.show()

def check(x,y):
    """Print, one per line, how many values of x labelled 0 are > 5, > 4 and < 3."""
    zeros = x[y == 0]
    for mask in (zeros > 5, zeros > 4, zeros < 3):
        print(int(mask.sum()))

def plot_show2(x,y):
    """Plot the values of x for digits 0 and 1, each in its own figure.

    Values are drawn against their sample index within the digit, with a
    per-figure legend. All figures are shown at the end.
    """
    palette = ["red", "green", "blue", "black", "pink",
               "brown", "gold", "yellow", "gray", "aqua"]
    for digit in (0, 1):
        series = x[y == digit]
        plt.figure()
        plt.plot(np.arange(len(series)), series,
                 c=palette[digit], label=str(digit))
        # Each figure gets its own legend.
        plt.legend(loc='best')
    plt.show()

def plot_show3(x, y, n_digits=3):
    """Overlay, in one figure, the values of x for digits 0..n_digits-1.

    Each digit's values are plotted against their sample index within that
    digit, as a separately colored and labelled line.

    Parameters
    ----------
    x : array-like
        Values to plot, aligned with ``y``.
    y : array-like
        Digit labels used to select the values for each line.
    n_digits : int
        Number of digits to overlay, starting at 0 (default 3, the
        original behavior). Must be between 1 and the palette size (10).

    Raises
    ------
    ValueError
        If ``n_digits`` is outside ``1..10``.
    """
    color_list = ["red", "green", "blue", "black", "pink",
                  "brown", "gold", "yellow", "gray", "aqua"]
    if not 1 <= n_digits <= len(color_list):
        raise ValueError('n_digits must be between 1 and %d, got %r'
                         % (len(color_list), n_digits))
    for i in range(n_digits):
        x_list = x[y == i]
        plt.plot(np.arange(len(x_list)), x_list, c=color_list[i], label=str(i))
    # All lines share one axes, so a single legend after plotting covers
    # every series. The previous per-iteration call just rebuilt it.
    plt.legend(loc='best')
    plt.show()

if __name__ == "__main__":
    X_train, y_train = load_mnist("./dataset/MNIST")
    # Alternative views, enabled by uncommenting:
    # plot_show(X_train, y_train)

    # Per-image L2 norm, scaled by the number of pixels per image
    # (X_train.shape[1]; 784 for 28x28 MNIST), then exponentiated.
    X_2norm = np.linalg.norm(X_train, ord=2, axis=1) / X_train.shape[1]
    X_2norm = np.exp(X_2norm)

    # scatter_show(X_2norm, y_train)
    # check(X_2norm, y_train)
    plot_show3(X_2norm, y_train)